【源头活水】一文教你彻底理解Google MLP-Mixer(附代码)
“问渠那得清如许,为有源头活水来”,通过前沿领域知识的学习,从其他研究领域得到启发,对研究问题的本质有更清晰的认识和理解,是自我提高的不竭源泉。为此,我们特别精选论文阅读笔记,开辟“源头活水”专栏,帮助你广泛而深入的阅读科研文献,敬请关注。
地址:https://zhuanlan.zhihu.com/p/372692759
随着深度神经网络发展至今,网络结构优化的瓶颈也慢慢显现出来。由此,文艺复兴随之出现,MLPs(Multi-layer Perceptrons)这种古老的结构也开始被重新拉上舞台。本文深入浅出介绍Google新坑,MLP-Mixer。
参考代码地址:https://github.com/lucidrains/mlp-mixer-pytorch
Google最近又挖了一个新坑,MLP-Mixer。原文提到,CNN以及self-attention这种相对复杂的网络结构在视觉任务上已经取得很好的表现了,但是我们真的需要这么复杂的网络结构吗?MLP这种简单的结构是否也能够取得SOTA的表现呢?MLP-Mixer给出了答案。
convolutions and attention are both sufficient for good performance, neither of them are necessary.--引自原文
01
MLP好理解,这个网络结构没有采用convolution以及attention的网络结构,纯粹使用MLP作为主要架构。
那为什么叫Mixer呢?举个例子就明白了,现在很多视觉任务的网络架构,其实就是mix不同的特征,找出各个特征之间的关系来获取有用的信息。从CNN的网络结构来理解就很简单了,拿一个NxNxC的kernel来举例,
(1)NxN这两个维度其实就是来mix不同位置像素点的mixer
(2)而C这个维度则是来mix一个像素点不同通道特征的mixer。
02
def MLPMixer(*, image_size, patch_size, dim, depth, num_classes, expansion_factor = 4, dropout = 0.):
assert (image_size % patch_size) == 0, 'image must be divisible by patch size'
num_patches = (image_size // patch_size) ** 2
chan_first, chan_last = partial(nn.Conv1d, kernel_size = 1), nn.Linear
return nn.Sequential(
# 1. 将图片拆成多个patches
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
# 2. 用一个全连接网络对所有patch进行处理,提取出tokens
nn.Linear((patch_size ** 2) * 3, dim),
# 3. 经过N个Mixer层,混合提炼特征信息
*[nn.Sequential(
PreNormResidual(dim, FeedForward(num_patches, expansion_factor, dropout, chan_first)),
PreNormResidual(dim, FeedForward(dim, expansion_factor, dropout, chan_last))
) for _ in range(depth)],
nn.LayerNorm(dim),
Reduce('b n c -> b c', 'mean'),
# 4. 最后一个全连接层进行类别预测
nn.Linear(dim, num_classes)
)
chan_first, chan_last = partial(nn.Conv1d, kernel_size = 1), nn.Linear
class PreNormResidual(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.fn = fn
self.norm = nn.LayerNorm(dim)
def forward(self, x):
return self.fn(self.norm(x)) + x
def FeedForward(dim, expansion_factor = 4, dropout = 0., dense = nn.Linear):
return nn.Sequential(
dense(dim, dim * expansion_factor),
nn.GELU(),
nn.Dropout(dropout),
dense(dim * expansion_factor, dim),
nn.Dropout(dropout)
)
03
04
参考资料
https://arxiv.org/abs/2105.01601
本文目的在于学术交流,并不代表本公众号赞同其观点或对其内容真实性负责,版权归原作者所有,如有侵权请告知删除。
“源头活水”历史文章
FaceBoxes阅读笔记
解析神经架构搜索(NAS)中权重共享的影响
CVPR2021 | 密集连接网络中的稀疏特征重激活
MLP-Mixer 里隐藏的卷积
深度学习结合传统几何的视觉定位方法:HSCNet简介
CVPR 2021 | 帮你理解域迁移:可视化网络知识的变化
视觉Transformer中的位置嵌入
多任务权重自动学习论文介绍和代码实现
Covariate Shift: 基于机器学习分类器的回顾和分析
NAS中基于MCT的搜索空间采样方法
LSNet: Anchor-free新玩法,只用一个head统一目标检测,实例分割,姿态估计三种任务
CV+Transformer之Swin Transformer
爆火的 Swin Transformer 到底做对了什么
mBART:多语言翻译预训练模型
更多源头活水专栏文章,
请点击文章底部“阅读原文”查看
分享、在看,给个三连击呗!